--- title: Training keywords: fastai sidebar: home_sidebar summary: "Notebook to train deep learning models or ensembles for segmentation of fluorescent labels in microscopy images." description: "Notebook to train deep learning models or ensembles for segmentation of fluorescent labels in microscopy images." nb_path: "nbs/_trash/_old/train-Copy1.ipynb" ---
{% raw %}
{% endraw %} {% raw %}
#@markdown Please run this cell to get started.
# Auto-reload edited modules so local deepflash2 changes are picked up without a kernel restart
%load_ext autoreload
%autoreload 2
# Colab-only imports; silently skipped when running outside Google Colab
try:
    from google.colab import files, drive
except ImportError:
    pass
# Install deepflash2 on first run if it is not already importable
try:
    import deepflash2
except ImportError:
    !pip install -q deepflash2
import zipfile
import shutil
import imageio
from sklearn.model_selection import KFold, train_test_split
from fastai.vision.all import *
from deepflash2.all import *
from deepflash2.data import _read_msk
from scipy.stats import entropy
{% endraw %}

Provide Training Data

{% raw %}
# Download and unpack the cFOS sample dataset used throughout this notebook.
# Bug fix: `urllib.request` was used without being imported; the cell only
# worked if a star import happened to leak `urllib` into the namespace.
import urllib.request

path = Path('sample_data_cFOS')
url = "https://github.com/matjesg/deepflash2/releases/download/model_library/wue1_cFOS_small.zip"
urllib.request.urlretrieve(url, 'sample_data_cFOS.zip')
unzip(path, 'sample_data_cFOS.zip')
{% endraw %}

Check and load data

{% raw %}
image_folder = "images" #@param {type:"string"}
mask_folder = "masks" #@param {type:"string"}
mask_suffix = "_mask.png" #@param {type:"string"}
#@markdown Number of classes: e.g., 2 for binary segmentation (foreground and background class)
n_classes = 2 #@param {type:"integer"}
#@markdown Check if you are providing instance labels (class-aware and instance-aware)
instance_labels = False #@param {type:"boolean"}

f_names = get_image_files(path/image_folder)

def label_fn(o):
    "Map an image file to its corresponding mask file via the mask folder and suffix."
    return path/mask_folder/f'{o.stem}{mask_suffix}'

# Verify that every image has a matching mask on disk before training
n_images = len(f_names)
n_masks = sum(os.path.isfile(label_fn(x)) for x in f_names)
if n_images > 0 and n_images == n_masks:
    print(f'Found {n_images} images and {n_masks} masks in "{path}".')
else:
    print(f'IMAGE/MASK MISMATCH! Found {n_images} images and {n_masks} masks in "{path}".')
    print('Please check the steps above.')
Found 5 images and 5 masks in "sample_data_cFOS".
{% endraw %} {% raw %}
# Build the ensemble learner on the sample images with a U-NeXt50 backbone
el = EnsembleLearner(f_names, label_fn, arch='unext50_deepflash2')
#el = EnsembleLearner(f_names, label_fn, arch='unet_deepflash2')
{% endraw %} {% raw %}
 
{% endraw %} {% raw %}
# Train only a single model instead of a full ensemble
el.set_n(1)
{% endraw %} {% raw %}
# Inspect the tile shape and padding derived for the dataset
el.ds_kwargs
{'tile_shape': (518, 518), 'padding': (126, 126)}
{% endraw %} {% raw %}
 
{% endraw %} {% raw %}
# Fit for 20 epochs; pre_ssl presumably enables self-supervised pretraining -- TODO confirm against deepflash2 docs
el.fit_ensemble(20, n_jobs=1, pre_ssl=True)
Creating weights for 0001.png
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
/media/data/deepflash2/deepflash2/data.py in _create_weights(self, n_jobs, verbose)
    359             try:
--> 360                 lbl, wgt, pdf = _get_cached_data(self._cache_fn(file.name))
    361                 if not using_cache:

NameError: name 'file' is not defined

During handling of the above exception, another exception occurred:

KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-11-b3806508d497> in <module>
----> 1 el.fit_ensemble(20, n_jobs=1, pre_ssl=True)

/media/data/deepflash2/deepflash2/learner.py in fit_ensemble(self, epochs, skip, **kwargs)
    145         for i in range(1, self.n+1):
    146             if skip and (i in self.models): continue
--> 147             self.fit(i, epochs,  **kwargs)
    148 
    149     def set_n(self, n):

/media/data/deepflash2/deepflash2/learner.py in fit(self, i, epochs, lr_max, bs, n_jobs, verbose, **kwargs)
    126         name = self.ensemble_dir/f'{self.arch}_model-{i}.pth'
    127         files_train, files_val = self.splits[i]
--> 128         train_ds = RandomTileDataset(files_train, label_fn=self.label_fn, n_jobs=n_jobs, verbose=verbose, **self.mw_kwargs, **self.ds_kwargs)
    129         valid_ds = TileDataset(files_val, label_fn=self.label_fn, n_jobs=n_jobs, verbose=verbose, **self.mw_kwargs,**self.ds_kwargs)
    130         batch_tfms = Normalize.from_stats(*self.stats)

/media/data/deepflash2/deepflash2/data.py in __init__(self, sample_mult, flip, rotation_range_deg, deformation_grid, deformation_magnitude, value_minimum_range, value_maximum_range, value_slope_range, *args, **kwargs)
    434     def __init__(self, *args, sample_mult=None, flip=True, rotation_range_deg=(0, 360), deformation_grid=(150, 150), deformation_magnitude=(10, 10),
    435                  value_minimum_range=(0, 0), value_maximum_range=(1, 1), value_slope_range=(1, 1), **kwargs):
--> 436         super().__init__(*args, **kwargs)
    437         store_attr('sample_mult, flip, rotation_range_deg, deformation_grid, deformation_magnitude, value_minimum_range, \
    438                     value_maximum_range, value_slope_range')

/media/data/deepflash2/deepflash2/data.py in __init__(self, files, label_fn, create_weights, instance_labels, n_classes, divide, ignore, tile_shape, padding, preproc_dir, bws, fds, bwf, fbr, n_jobs, verbose, **kwargs)
    333             else: self.preproc_dir = Path(preproc_dir)
    334             self.preproc_dir.mkdir(exist_ok=True, parents=True)
--> 335             if create_weights: self._create_weights(n_jobs, verbose)
    336 
    337     def _cache_fn(self, o):

/media/data/deepflash2/deepflash2/data.py in _create_weights(self, n_jobs, verbose)
    365                 if n_jobs==1:
    366                     if verbose>0: print('Creating weights for', f.name)
--> 367                     self._preproc(f)
    368                 else:
    369                     preproc_queue.append(f)

/media/data/deepflash2/deepflash2/data.py in _preproc(self, file)
    349             instlabels = None
    350         ign = self.ignore[file.name] if file.name in self.ignore else None
--> 351         lbl, wgt, pdf = calculate_weights(clabels, instlabels, ignore=ign, n_dims=self.c,
    352                                           bws=self.bws, fds=self.fds, bwf=self.bwf, fbr=self.fbr)
    353         np.savez_compressed(self._cache_fn(file.name), lbl=lbl, wgt=wgt, pdf=pdf)

/media/data/deepflash2/deepflash2/data.py in calculate_weights(clabels, instlabels, ignore, n_dims, bws, fds, bwf, fbr)
    179             dt = ndimage.morphology.distance_transform_edt(instlabels != instance)
    180 
--> 181             frgrd_dist += np.exp(-dt ** 2 / (2*fds ** 2))
    182             min2dist = np.minimum(min2dist, dt)
    183             newMin1 = np.minimum(min1dist, min2dist)

KeyboardInterrupt: 
{% endraw %} {% raw %}
# Plot predictions on the validation splits
el.show_validation_results()
Using cache found in /media/data/home/mag01ud/.cache/torch/hub/matjesg_deepflash2_master
{% endraw %} {% raw %}
# Instantiate the raw UneXt50 architecture and echo its layer structure
m=UneXt50()
m
Using cache found in /media/data/home/mag01ud/.cache/torch/hub/facebookresearch_semi-supervised-ImageNet1K-models_master
UneXt50(
  (enc0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (enc1): Sequential(
    (0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
    )
  )
  (enc2): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (enc3): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (4): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (5): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (enc4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (aspp): ASPP(
    (aspps): ModuleList(
      (0): _ASPPModule(
        (atrous_conv): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
      )
      (1): _ASPPModule(
        (atrous_conv): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
      )
      (2): _ASPPModule(
        (atrous_conv): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), groups=4, bias=False)
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
      )
      (3): _ASPPModule(
        (atrous_conv): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), groups=4, bias=False)
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
      )
      (4): _ASPPModule(
        (atrous_conv): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), groups=4, bias=False)
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
      )
    )
    (global_pool): Sequential(
      (0): AdaptiveMaxPool2d(output_size=(1, 1))
      (1): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
    )
    (out_conv): Sequential(
      (0): Conv2d(1536, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (conv1): Conv2d(1536, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (drop_aspp): Dropout2d(p=0.5, inplace=False)
  (dec4): UnetBlock(
    (shuf): PixelShuffle_ICNR(
      (0): ConvLayer(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU()
      )
      (1): PixelShuffle(upscale_factor=2)
    )
    (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv1): ConvLayer(
      (0): Conv2d(1280, 256, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
    )
    (conv2): ConvLayer(
      (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
    )
    (relu): ReLU(inplace=True)
  )
  (dec3): UnetBlock(
    (shuf): PixelShuffle_ICNR(
      (0): ConvLayer(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU()
      )
      (1): PixelShuffle(upscale_factor=2)
    )
    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv1): ConvLayer(
      (0): Conv2d(640, 128, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
    )
    (conv2): ConvLayer(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
    )
    (relu): ReLU(inplace=True)
  )
  (dec2): UnetBlock(
    (shuf): PixelShuffle_ICNR(
      (0): ConvLayer(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU()
      )
      (1): PixelShuffle(upscale_factor=2)
    )
    (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv1): ConvLayer(
      (0): Conv2d(320, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
    )
    (conv2): ConvLayer(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
    )
    (relu): ReLU(inplace=True)
  )
  (dec1): UnetBlock(
    (shuf): PixelShuffle_ICNR(
      (0): ConvLayer(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU()
      )
      (1): PixelShuffle(upscale_factor=2)
    )
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv1): ConvLayer(
      (0): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
    )
    (conv2): ConvLayer(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
    )
    (relu): ReLU(inplace=True)
  )
  (fpn): FPN(
    (convs): ModuleList(
      (0): Sequential(
        (0): Conv2d(512, 32, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1): Sequential(
        (0): Conv2d(256, 32, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (2): Sequential(
        (0): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (3): Sequential(
        (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
  )
  (drop): Dropout2d(p=0.1, inplace=False)
  (final_conv): ConvLayer(
    (0): Conv2d(96, 2, kernel_size=(1, 1), stride=(1, 1))
  )
)
{% endraw %} {% raw %}
# Run inference with model 1 on all images; energy_ks=None presumably skips the energy kernel step -- TODO confirm
res = el.predict(f_names, 1,energy_ks=None)
Using cache found in /media/data/home/mag01ud/.cache/torch/hub/matjesg_deepflash2_master
{% endraw %} {% raw %}
# Show the fourth prediction output (res[3]) for every input image
for idx, fname in enumerate(f_names):
    print(fname)
    figure, axis = plt.subplots(nrows=1, ncols=1, figsize=(30, 30))
    axis.imshow(res[3][idx])
    plt.show()
sample_data_cFOS/images/0001.png
sample_data_cFOS/images/0004.png
sample_data_cFOS/images/0007.png
sample_data_cFOS/images/0008.png
sample_data_cFOS/images/0006.png
{% endraw %} {% raw %}
# Combine the per-model predictions into ensemble results
el.get_ensemble_results(f_names)
/media/data/anaconda3/envs/fastai/lib/python3.8/site-packages/fastai/callback/core.py:50: UserWarning: You are setting an attribute (__class__) that also exists in the learner, so you're not setting it in the learner but in the callback. Use `self.learn.__class__` otherwise.
  warn(f"You are setting an attribute ({name}) that also exists in the learner, so you're not setting it in the learner but in the callback. Use `self.learn.{name}` otherwise.")
Using cache found in /media/data/home/mag01ud/.cache/torch/hub/matjesg_deepflash2_master
{% endraw %} {% raw %}
# Per-file summary table of the ensemble predictions
el.df_ens
file model img_path res_path energy_max
0 0001.png unext50_deepflash2_ensemble sample_data_cFOS/images/0001.png .tmp/unext50_deepflash2_ensemble/0001.npz 2.616549977362156
1 0004.png unext50_deepflash2_ensemble sample_data_cFOS/images/0004.png .tmp/unext50_deepflash2_ensemble/0004.npz 3.180415459275246
2 0007.png unext50_deepflash2_ensemble sample_data_cFOS/images/0007.png .tmp/unext50_deepflash2_ensemble/0007.npz 2.9121269342303275
3 0008.png unext50_deepflash2_ensemble sample_data_cFOS/images/0008.png .tmp/unext50_deepflash2_ensemble/0008.npz 4.4454662615060805
4 0006.png unext50_deepflash2_ensemble sample_data_cFOS/images/0006.png .tmp/unext50_deepflash2_ensemble/0006.npz 4.255264647006989
{% endraw %} {% raw %}
# Per-file, per-model summary table (here only model 1 was trained)
el.df_models
file model_no model img_path res_path energy_max
0 0001.png 1 unext50_deepflash2_model-1.pth sample_data_cFOS/images/0001.png .tmp/unext50_deepflash2_model-1.pth/0001.npz 2.616549977362156
1 0004.png 1 unext50_deepflash2_model-1.pth sample_data_cFOS/images/0004.png .tmp/unext50_deepflash2_model-1.pth/0004.npz 3.180415459275246
2 0007.png 1 unext50_deepflash2_model-1.pth sample_data_cFOS/images/0007.png .tmp/unext50_deepflash2_model-1.pth/0007.npz 2.9121269342303275
3 0008.png 1 unext50_deepflash2_model-1.pth sample_data_cFOS/images/0008.png .tmp/unext50_deepflash2_model-1.pth/0008.npz 4.4454662615060805
4 0006.png 1 unext50_deepflash2_model-1.pth sample_data_cFOS/images/0006.png .tmp/unext50_deepflash2_model-1.pth/0006.npz 4.255264647006989
{% endraw %} {% raw %}
# Visualize the ensemble result for one file
el.show_ensemble_results(file="0001.png")
{% endraw %} {% raw %}
# Choose the U-Net architecture variant (Colab dropdown)
model_arch = 'unet_deepflash2' #@param ["unet_deepflash2",  "unet_falk2019", "unet_ronnberger2015"]
{% endraw %}

Pretrained weights

  • Select 'new' to use an untrained model (no pretrained weights)
  • Or select pretrained model weights from the dropdown menu
{% raw %}
pretrained_weights = "wue_cFOS" #@param ["new", "wue_cFOS", "wue_Parv", "wue_GFAP", "wue_GFP", "wue_OPN3"]
# Use pretrained weights for anything except the "new" (untrained) option
pre = False if pretrained_weights=="new" else True
# NOTE(review): `ds` is not defined anywhere in the visible cells above -- this
# presumably expects a dataset built in an earlier (removed) cell; verify before running.
n_channels = ds.get_data(max_n=1)[0].shape[-1]
model = torch.hub.load('matjesg/deepflash2', model_arch, pretrained=pre, dataset=pretrained_weights, n_classes=ds.c, in_channels=n_channels)
# Fresh weight initialization for untrained models
if pretrained_weights=="new": apply_init(model)
{% endraw %}

Setting model hyperparameters (optional)

  • mixed_precision_training: enables Mixed precision training
    • decreases memory usage and speed-up training
    • may affect model accuracy
  • batch_size: the number of samples that will be propagated through the network during one iteration
{% raw %}
mixed_precision_training = False #@param {type:"boolean"}
batch_size = 4 #@param {type:"slider", min:2, max:8, step:2}
# Pixel-weighted cross-entropy over the class axis
loss_fn = WeightedSoftmaxCrossEntropy(axis=1)
# Elastic deformation augmentation via a fastai callback
cbs = [ElasticDeformCallback]
dls = DataLoaders.from_dsets(ds, ds, bs=batch_size)
if torch.cuda.is_available():
    dls.cuda()
    model.cuda()
learn = Learner(dls, model, wd=0.001, loss_func=loss_fn, cbs=cbs)
if mixed_precision_training:
    learn.to_fp16()
{% endraw %}
  • max_lr: The learning rate controls how quickly or slowly a neural network model learns.
    • We found that a maximum learning rate of 5e-4 (i.e., 0.0005) yielded the best results across experiments.
    • learning_rate_finder: Check only if you want to use the Learning Rate Finder on your dataset.
{% raw %}
learning_rate_finder = False #@param {type:"boolean"}
# Optionally run fastai's learning-rate finder to suggest a value for max_lr
if learning_rate_finder:
    lr_min,lr_steep = learn.lr_find()
    print(f"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}")
{% endraw %} {% raw %}
# Maximum learning rate for one-cycle training (5e-4 worked best across experiments)
max_lr = 5e-4 #@param {type:"number"}
{% endraw %}

Model Training

Setting training parameters

  • n_models: Number of models to train.
    • If you're experimenting with parameters, try only one model first.
    • Depending on the data, ensembles should comprise 3-5 models.
    • _Note: The number of models affects the train-validation split._
{% raw %}
# Fall back to defaults if the configuration cells above were not executed.
# Bug fix: the original used bare `except:`, which also swallows
# KeyboardInterrupt/SystemExit; only a NameError (undefined variable) is expected here.
try:
    batch_size = batch_size
except NameError:
    batch_size = 4
    mixed_precision_training = False
    loss_fn = WeightedSoftmaxCrossEntropy(axis=1)
try:
    max_lr = max_lr
except NameError:
    max_lr = 5e-4

# Segmentation metrics reported during training
metrics = [Dice_f1(), Iou()]
n_models = 1 #@param {type:"slider", min:1, max:5, step:1}
print("Suggested epochs for 1000 iterations:", calc_iterations(len(ds), batch_size, n_models))
{% endraw %}
  • epochs: One epoch is when an entire (augmented) dataset is passed through the model for training.
    • Epochs need to be adjusted depending on the size and number of images
    • We found that choosing the number of epochs such that the network parameters are updated about 1000 times (iterations) leads to satisfying results in most cases.
{% raw %}
# Number of passes over the (augmented) training set
epochs = 30 #@param {type:"slider", min:1, max:200, step:1}
{% endraw %}

Train models

{% raw %}
import ipywidgets as widgets
{% endraw %} {% raw %}
# Build a tabbed widget layout with a training pane, a parameter pane, and an
# LR-finder pane. Idiom fixes: the original zipped `tab_contents` against
# `enumerate(tab.children)` just to recover indices -- `enumerate(tab_contents)`
# does the same directly; the `f'test'` f-string had no placeholders.
train_label = widgets.Label('test')
train_out = widgets.Output()

train_list = [train_label, train_out]
train_box = widgets.VBox(train_list)

param_box = widgets.VBox()
finder_box = widgets.VBox()

tab_contents = ['Train', 'Parameters', 'LR-Finder']
children = [train_box, param_box, finder_box]
tab = widgets.Tab()
tab.children = children
for index, title in enumerate(tab_contents):
    tab.set_title(index, title)
tab
{% endraw %} {% raw %}
# Plot the recorded training/validation metrics
learn.recorder.plot_metrics()
{% endraw %} {% raw %}
from IPython.display import display

button = widgets.Button(description="Click Me!")
output = widgets.Output()
display(button, output)

def on_button_clicked(b):
    "Run a short one-cycle training when the button is pressed, printing into the output widget."
    with output:
        print("Button clicked.")
        with learn.no_logging():
            learn.fit_one_cycle(3, max_lr)

button.on_click(on_button_clicked)
{% endraw %} {% raw %}
# Scratch cell: AppLayout widget with only a header text field.
widgets.AppLayout(header=widgets.Text(description="Test"))
{% endraw %} {% raw %}
# Scratch cell: build a learner with weight decay and the training callbacks.
learn = Learner(dls, model, metrics = metrics, wd=0.001, loss_func=loss_fn, cbs=cbs)
{% endraw %} {% raw %}
# Scratch cell: build a plain learner without callbacks or weight decay.
learn = Learner(dls, model, metrics = metrics, loss_func=loss_fn)
{% endraw %} {% raw %}
# NOTE(review): no_logging() returns a context manager; calling it bare like
# this has no lasting effect — it needs `with learn.no_logging(): ...`.
learn.no_logging()
{% endraw %} {% raw %}
# Scratch cell: create and display an Output widget for the capture tests below.
o_test = widgets.Output()
o_test
{% endraw %} {% raw %}
# Scratch cell: verify that prints are captured by the Output widget above.
with o_test: print("ss")
{% endraw %} {% raw %}
# Scratch cell: check progress_bar rendering inside an Output widget while
# fastai's own progress output is disabled.
with o_test:
    with progress_disabled():
        for x in progress_bar(range(100)):
            print(x)
{% endraw %} {% raw %}
# Scratch cell: run training with its output captured in the Train tab widget.
with train_out: learn.fit_one_cycle(3, max_lr)
{% endraw %} {% raw %}
# Scratch cell: short one-cycle training run.
learn.fit_one_cycle(3, max_lr)
{% endraw %} {% raw %}
# NOTE(review): bare attribute access — displays the bound method, does not train.
learn.fit_one_cycle
{% endraw %} {% raw %}
# Scratch cell: inspect the first k-fold split. NOTE(review): `kf` is defined
# in the next cell, so this only works after that cell has been executed.
list(kf.split(f_names))[0]
{% endraw %} {% raw %}
# K-fold training loop: trains `n_models` models, each on its own
# train/validation split, and collects per-file (MC-dropout) predictions
# in `res` / `res_mc` keyed by (model name, file).
kf = KFold(n_splits=max(n_models, 2))  # KFold requires at least 2 splits
model_path = path/'models'
model_path.mkdir(parents=True, exist_ok=True)
res, res_mc = {}, {}
fold = 0
for train_idx, val_idx in kf.split(f_names):
    fold += 1
    name = f'model{fold}'
    print('Train', name)
    if n_models==1:
        # Single model: a plain random split instead of the k-fold split.
        files_train, files_val = train_test_split(f_names)
    else:
        files_train, files_val = f_names[train_idx], f_names[val_idx]
    print(f'Validation Images: {files_val}')
    train_ds = RandomTileDataset(files_train, label_fn, **mw_dict)
    valid_ds = TileDataset(files_val, label_fn, **mw_dict)

    dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=batch_size)
    # Deterministic loader (no shuffle/drop_last) for validation predictions.
    dls_valid = DataLoaders.from_dsets(valid_ds, batch_size=batch_size, shuffle=False, drop_last=False)
    model = torch.hub.load('matjesg/deepflash2', model_arch, pretrained=pre,
                           dataset=pretrained_weights, n_classes=ds.c, in_channels=n_channels)
    if pretrained_weights=="new": apply_init(model)
    if torch.cuda.is_available(): dls.cuda(), model.cuda(), dls_valid.cuda()

    cbs = [SaveModelCallback(monitor='iou'), ElasticDeformCallback, ShowGraphCallback]
    metrics = [Dice_f1(), Iou()]
    learn = Learner(dls, model, metrics=metrics, wd=0.001, loss_func=loss_fn, cbs=cbs)
    if mixed_precision_training: learn.to_fp16()
    # NOTE(review): removed leftover `assert False` that aborted the loop
    # here before any training could run.
    learn.fit_one_cycle(epochs, max_lr)
    # Legacy serialization format keeps the weights loadable by older torch.
    torch.save(learn.model.state_dict(), model_path/f'{name}.pth', _use_new_zipfile_serialization=False)

    # Plain and MC-dropout (10 passes) tile predictions on the validation set.
    smxs, segs, _ = learn.predict_tiles(dl=dls_valid.train)
    smxs_mc, segs_mc, std = learn.predict_tiles(dl=dls_valid.train, mc_dropout=True, n_times=10)

    for i, file in enumerate(files_val):
        res[(name, file)] = smxs[i], segs[i]
        res_mc[(name, file)] = smxs_mc[i], segs_mc[i], std[i]

    if n_models==1:
        break  # only one model requested; skip the remaining folds
{% endraw %}

Validate models

Here you can validate your models. To avoid information leakage, only predictions on the respective models' validation set are made.

{% raw %}
# Output locations for ensemble validation predictions, uncertainty maps,
# and the aggregated results table.
pred_dir = 'val_preds' #@param {type:"string"}
pred_path = path/pred_dir/'ensemble'
uncertainty_dir = 'val_uncertainties' #@param {type:"string"}
uncertainty_path = path/uncertainty_dir/'ensemble'
result_path = path/'results'

# Create the directory trees (idempotent if they already exist).
for target in (pred_path, uncertainty_path):
    target.mkdir(parents=True, exist_ok=True)
result_path.mkdir(exist_ok=True)

#@markdown Define `filetype` to save the predictions and uncertainties. All common [file formats](https://imageio.readthedocs.io/en/stable/formats.html) are supported.
filetype = 'png' #@param {type:"string"}
{% endraw %} {% raw %}
# Per-model validation: compute IoU and uncertainty entropy on each model's
# own validation images and write prediction / uncertainty maps to disk.
res_list = []
for model_number in range(1, n_models+1):
    model_name = f'model{model_number}'
    # Only files this model did NOT train on (avoids information leakage).
    val_files = [f for mod, f in res.keys() if mod == model_name]
    print(f'Validating {model_name}')
    pred_path = path/pred_dir/model_name
    pred_path.mkdir(parents=True, exist_ok=True)
    uncertainty_path = path/uncertainty_dir/model_name
    uncertainty_path.mkdir(parents=True, exist_ok=True)
    for file in val_files:
        img = ds.get_data(file)[0]
        msk = ds.get_data(file, mask=True)[0]
        pred = res[(model_name, file)][1]
        # First channel of the MC-dropout standard deviation map.
        pred_std = res_mc[(model_name, file)][2][..., 0]
        df_tmp = pd.Series({'file': file.name,
                            'model': model_name,
                            'iou': iou(msk, pred),
                            'entropy': entropy(pred_std, axis=None)})
        plot_results(img, msk, pred, pred_std, df=df_tmp)
        res_list.append(df_tmp)
        # Multi-class masks keep their raw class ids; binary masks are scaled
        # to 0/255 so they are visible in standard image viewers.
        imageio.imsave(pred_path/f'{file.stem}_pred.{filetype}', pred.astype(np.uint8) if np.max(pred)>1 else pred.astype(np.uint8)*255)
        # BUGFIX: scale BEFORE casting — the old `pred_std.astype(np.uint8)*255`
        # truncated the float std (typically in [0, 1)) to all zeros.
        imageio.imsave(uncertainty_path/f'{file.stem}_uncertainty.{filetype}', (pred_std*255).astype(np.uint8))
df_res = pd.DataFrame(res_list)
df_res.to_csv(result_path/'val_results.csv', index=False)
{% endraw %} {% raw %}
# Scratch cell: truthiness check — `not None` evaluates to True.
x = None
not x
{% endraw %}

Download Section

  • The models will always be the last version trained in section Model Training
  • To download validation predictions and uncertainties, you first need to execute section Validate models.

Note: If you're connected to Google Drive, the models are automatically saved to your drive.

{% raw %}
# Download the selected trained model (file download only works on Colab).
model_number = "1" #@param ["1", "2", "3", "4", "5"]
model_path = path/'models'/f'model{model_number}.pth'
try:
    files.download(model_path)
except Exception:
    # `files` is undefined outside Colab (its import above is best-effort);
    # narrowed from a bare `except:` so Ctrl-C/SystemExit still propagate.
    print("Warning: File download only works on Google Colab.")
    print(f"Models are saved at {model_path.parent}")
{% endraw %} {% raw %}
# Zip all validation predictions and offer the archive for download.
out_name = 'val_predictions'
shutil.make_archive(path/out_name, 'zip', path/pred_dir)
try:
    files.download(path/f'{out_name}.zip')
except Exception:
    # `files` is undefined outside Colab; narrowed from a bare `except:`.
    print("Warning: File download only works on Google Colab.")
{% endraw %} {% raw %}
# Zip all validation uncertainty maps and offer the archive for download.
out_name = 'val_uncertainties'
shutil.make_archive(path/out_name, 'zip', path/uncertainty_dir)
try:
    files.download(path/f'{out_name}.zip')
except Exception:
    # `files` is undefined outside Colab; narrowed from a bare `except:`.
    print("Warning: File download only works on Google Colab.")
{% endraw %} {% raw %}
# Offer the aggregated validation results table for download.
try:
    files.download(result_path/f'val_results.csv')
except Exception:
    # `files` is undefined outside Colab; narrowed from a bare `except:`.
    print("Warning: File download only works on Google Colab.")
{% endraw %}